#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Feb 17 11:44:57 2024

@author: jcfq2
"""

import os

# Work from the JWST data directory. Relative paths used later in this file
# (e.g. 'F323N_filter.txt', 'asd_fig1b.pdf') are resolved against it. In an
# interactive session where '' is on sys.path this presumably also lets the
# local modules imported below be found — TODO confirm.
jwst_dir = '/Users/jcfq2/data/observations/jwst'
os.chdir(jwst_dir)

import matplotlib.pyplot as plt
import matplotlib.colors as colours
from astropy.visualization import make_lupton_rgb
from astropy.table import Table
import ch4_fiddlesticks as ch4
import pandas as pd
import h3ppy
import glob
import spiceypy as spice
from astropy.io import fits
import numpy as np
from importlib import reload
import JWSTSolarSystemPointing as jssp
reload(jssp)
from scipy.interpolate import interp1d

# In[0]

# Load the SPICE kernels from kernel_dir via jssp.load_kernels.
kernel_dir = '/Users/jcfq2/data/observations/jwst/kernels/'
jssp.load_kernels(kdir=kernel_dir)

# General-purpose h3ppy model object. NOTE: the name h3p_model is rebound
# further down this file to a normalised model spectrum array.
h3p_model = h3ppy.h3p()

# In[1]

# Reference cube: the dither-combined G395H/F290LP observation.
file_dir = '/Users/jcfq2/data/observations/jwst/5308_dither_combined/'
fn = 'jw05308001001_03119_g395h-f290lp_s3d.fits'
file = f'{file_dir}{fn}'

# JWSTSolarSystemPointing object wrapping this cube (geometry, wavelengths,
# unit conversion).
geo = jssp.JWSTSolarSystemPointing(file)

# Geometry for the whole IFU field of view (a three-dimensional array).
cube = geo.full_fov()

# Wavelength scale; expected to be the same for every NIRSpec G395H/F290LP
# observation.
wave = geo.get_wavelength()

# Bare expression on purpose: in an interactive cell this displays the
# available geometric parameters as a one-column table.
pd.DataFrame(geo.keys, columns=["Parameter"])

# In[1]


def _make_dark_ramp_cmap(name, base_cmap, n_ramp=32, n_base=220):
    """Return a colormap that fades linearly from black into *base_cmap*.

    The map has ``n_ramp + n_base`` entries:

    * the last ``n_base`` are *base_cmap* sampled evenly over [0, 1];
    * the first ``n_ramp`` scale the RGB of the first base colour by
      r / n_ramp (row r = 0 .. n_ramp-1), so the map starts at black.
      Their alpha comes from ``cubehelix_r``, exactly as in the original
      hand-written recipe.

    Parameters
    ----------
    name : str
        Name passed to ``LinearSegmentedColormap.from_list``.
    base_cmap : matplotlib colormap
        Sequential map to fade into (the ``*_r`` maps used here start dark).
    n_ramp, n_base : int
        Number of ramp and base entries (defaults match the original 32/220).
    """
    base = base_cmap(np.linspace(0, 1, n_base))
    ramp = plt.cm.cubehelix_r(np.linspace(0.0, 0.4, n_ramp))
    colors = np.vstack((ramp, base))
    # Vectorised form of the original per-row loop:
    # colors[r, :3] = colors[n_ramp, :3] * r / n_ramp  (same operation order).
    steps = np.arange(n_ramp)[:, None]
    colors[:n_ramp, :3] = colors[n_ramp, :3] * steps / n_ramp
    return colours.LinearSegmentedColormap.from_list(name, colors)


mygreen = _make_dark_ramp_cmap('my_green', plt.cm.Greens_r)
myred = _make_dark_ramp_cmap('my_red', plt.cm.Reds_r)
myblue = _make_dark_ramp_cmap('my_blue', plt.cm.Blues_r)
mypurple = _make_dark_ramp_cmap('my_purple', plt.cm.Purples_r)




# In[9] plotting figure from movie
from JWSTSolarSystemPointing import get_pixel_polygons as jssp_gpp

from pypolyclip import clip_multi, clip_single
 


# Figure canvas; every panel is placed by hand with fig.add_axes.
# (An earlier, wider layout used figsize=(15.6, 6.08).)
fig = plt.figure(figsize=(10, 10))

# Left edges (figure fractions) of the panel columns: image columns at 0 and
# 0.25; the spectrum panels are positioned relative to xpos[2] = 0.58.
xpos = np.array([0.0, 0.25, 0.58, 0.75])
# Width of one image panel (figure fraction).
xsize = 0.25

# Quarter-grid row positions/height (not read elsewhere in this file; rows
# are placed with literal bottom values).
ypos = np.arange(4) / 4
ysize = 0.25

# for xp in range(6):
#     for yp in range(3):
        

 

# # Add the second subplot with a custom position
# ax2 = fig.add_axes([0.6, 0.1, 0.3, 0.4])  # [left, bottom, width, height]
# ax2.plot([4, 5, 6], [1, 2, 3])
# ax2.set_title('Subplot 2')

# Show the plot




# Individual dither-separated cubes, sorted by filename.
dither_glob = "/Users/jcfq2/data/observations/jwst/5308_dither_separated/*.fits"
files = sorted(glob.glob(dither_glob))
lens = len(files)

# Bookkeeping values carried over from the movie version of this script;
# none of them is read again in this file.
z0 = zz = 0
x_offset, y_offset = 5, -3
xxx = 0
starter = 1

savedir = "/Users/jcfq2/OneDrive - Northumbria University - Production Azure AD/python/jwst/saturn/movies/"



# Pick one dither-separated cube (entry 18 of the sorted list) and compute
# its wavelength scale and geometry.
i = 4*4 + 2
if i >= len(files):
    # Same exception type as a bare files[i], but with an actionable message.
    raise IndexError(
        f"need at least {i + 1} cubes in the dither-separated directory, "
        f"found {len(files)}")

geo = jssp.JWSTSolarSystemPointing(files[i])
wave = geo.get_wavelength()
# Full-FOV geometry, computed once (previously full_fov() was called twice in
# succession and the first result discarded).
cube = geo.full_fov()
spec = geo.convert(wave, geo.im[:, 25, 25])
cube_1 = geo.full_fov(corner=1)
cube_2 = geo.full_fov(corner=2)
cube_3 = geo.full_fov(corner=3)
cube_4 = geo.full_fov(corner=4)

# if i == 0:
#     print(geo.obs_start)


def _median_map(data, channels):
    """Collapse *data* over the given wavelength channels with a NaN-aware
    median, then rotate the 2-D result with rot90(k=3)."""
    return np.rot90(np.nanmedian(data[channels], axis=0), k=3)


# Spatial crop: drop the first 4 and last 3 entries of the last axis.
sy, ey = 4, -3
geoim = geo.im[:, :, sy:ey]

# Reflectance map (img_reflect): median over the 3.0-3.2 window.
wavemin, wavemax = 3.0, 3.2
whw = np.flatnonzero((wave > wavemin) & (wave < wavemax))
img_reflect = _median_map(geoim, whw)

# Single-channel slice; replaced by the bbb - aaa difference below.
img_methane = geoim[2775, :, :]

# Five-channel medians at three offsets from channel 672.
aaa = _median_map(geoim, slice(672, 677))
bbb = _median_map(geoim, slice(672 + 7, 677 + 7))
ccc = _median_map(geoim, slice(672 + 22, 677 + 22))

img_methane = bbb - aaa



# --- H3+ emission maps -------------------------------------------------
# First pass: median over the narrow 3.9529-3.9535 window. img_h3p is
# reassigned to ab - cd at the end of this block, so this value is
# superseded.
wavemin = 3.9529
wavemax = 3.9535

whw = np.argwhere((wave > wavemin) & (wave < wavemax)).flatten()
img_h3p = np.nanmedian(geoim[whw, :, :], axis=0)

# ab: unweighted sum of channels 2239, 1568, 1642, 1010, 1011, 1018, 1019
#     (presumably H3+ line channels — TODO confirm).
# cd: weighted sum of neighbouring channels, presumably a local background
#     estimate for ab. Effective weights (cf. the commented prints below):
#     2241 x0.4, 2236 x0.595, 1571 x0.5, 1565 x0.5, 1640 x0.0768,
#     1004 x1, 1005 x2.
# Both are rotated with rot90(k=3), like the other image maps. The
# hand-written addition order fixes the floating-point rounding; keep it
# when editing.
ab=np.rot90(geoim[2239,:,:]+geoim[1568,:,:]+geoim[1642,:,:]+geoim[1010,:,:]+geoim[1011,:,:]+geoim[1018,:,:]+geoim[1019,:,:],3)
cd=np.rot90((geoim[2241,:,:]*0.8*0.5+geoim[2236,:,:]*1.19*0.5)+((geoim[1571,:,:]+geoim[1565,:,:])*0.5)+(0.0768*geoim[1640,:,:])+(geoim[1004,:,:]+geoim[1005,:,:]*2),3)
# Unrotated single-channel slice. Later code uses only its shape and dtype,
# as a template for the F323 images.
ef =geoim[2239,:,:]
# print(wave[[1010,1011,1018,1019,1568,1642,2239]])
# print(wave[[1004,1005,1571,1565,1640  ,2236 ,2241]])
# print(    [1   ,2   ,0.5 ,0.5 ,0.0768,0.595,0.4 ])

# Final H3+ map: ab minus cd (presumably line minus background).
img_h3p = ab-cd

# 5.0-5.2 map (img_heat). The single-channel slice assigned first is
# replaced straight away by the windowed median.
img_heat = geoim[3487, :, :]
wavemin, wavemax = 5.0, 5.2

in_heat_band = (wave > wavemin) & (wave < wavemax)
whw = np.flatnonzero(in_heat_band)
img_heat = np.rot90(np.nanmedian(geoim[whw], axis=0), k=3)

#ra, dec = geo.get_delta_ra_dec_arcsec()

# im_h3p[im_h3p > np.nanmedian(im_h3p)*3] = np.nanmedian(im_h3p)*3
# im_h3p[im_h3p<0] = 0
    

# --- F323N synthetic-filter set-up -------------------------------------
# whw_F323: channels shown in the F323N spectrum panel (3.148-3.38).
wavemin_F323 = 3.148
wavemax_F323 = 3.38
whw_F323 = np.argwhere((wave > wavemin_F323) & (wave < wavemax_F323)).flatten()

# whw_F323f: channels where the filter curve is evaluated (3.148-3.35);
# the transmission stays zero on every other channel.
wavemin_F323f = 3.148
wavemax_F323f = 3.35
whw_F323f = np.argwhere((wave > wavemin_F323f) & (wave < wavemax_F323f)).flatten()

# PSG background radiance resampled onto whw_F323 (author's note: "could be
# problematic - lots going on in the background"). interp1d raises
# ValueError if a requested wavelength lies outside the file's range.
file_F323 = '/Users/jcfq2/data/observations/jwst/saturn/PSG_files/psg_rad_F323.txt'
bck_F323 = np.loadtxt(file_F323)
f_F323 = interp1d(bck_F323[:, 0], bck_F323[:, 1], kind='linear')
bck_F323_whw = f_F323(wave[whw_F323])

# Filter transmission table (relative path: read from the working directory).
F323_values = np.loadtxt('F323N_filter.txt', delimiter=' ')
f_F323f = interp1d(F323_values[:, 0], F323_values[:, 1], kind='linear')
filter_F323_sub = f_F323f(wave[whw_F323f])

# Transmission on the full wavelength grid.
filter_F323 = np.zeros_like(wave)
filter_F323[whw_F323f] = filter_F323_sub

# Filter-weighted background (not read again in this file).
bck_F323_whw = filter_F323[whw_F323] * bck_F323_whw

# Normalised (peak = 1) H3+ model spectrum on the same grid. NOTE: this
# rebinds h3p_model, which earlier in the file held an h3ppy object.
h3p = h3ppy.h3p()
h3p.set(T=500, N=6e14, wave=wave, R=2000)
h3p_model = h3p.model()
h3p_model = h3p_model / np.max(h3p_model)

# Split the filter into H3+-line and non-line channels using one mask and its
# complement, so the two pieces exactly partition filter_F323. The previous
# `< 0.001` / `> 0.001` pair kept channels sitting exactly at the threshold
# in both pieces.
h3p_line_threshold = 0.001
on_h3p_line = h3p_model >= h3p_line_threshold
filter_F323_h3p = filter_F323.copy()
filter_F323_h3p[~on_h3p_line] = 0
filter_F323_without_h3p = filter_F323.copy()
filter_F323_without_h3p[on_h3p_line] = 0

    

def _filter_weighted_image(data, weights, like):
    """Return the per-spaxel NaN-ignoring sum of data * weights over axis 0.

    Parameters
    ----------
    data : ndarray, shape (n_wave, ny, nx)
        Spectral cube (wavelength first, as indexed elsewhere in this file).
    weights : ndarray, shape (n_wave,)
        Per-channel filter weights.
    like : ndarray, shape (ny, nx)
        Template; the result takes its shape and dtype (as the original
        zeros_like + per-pixel assignment did).
    """
    image = np.zeros_like(like)
    # One vectorised reduction instead of a Python double loop. The sum order
    # differs from per-pixel 1-D nansum, so results may differ at rounding level.
    image[...] = np.nansum(data * weights[:, None, None], axis=0)
    return image


# Synthetic F323N photometry per spaxel: in-line, out-of-line and full band.
F323_h3p_image = _filter_weighted_image(geoim, filter_F323_h3p, ef)
noF323_h3p_image = _filter_weighted_image(geoim, filter_F323_without_h3p, ef)
F323_full_image = _filter_weighted_image(geoim, filter_F323, ef)

# Only the full-band image is rotated (rot90 k=3), like the other image maps.
F323_full_image = np.rot90(F323_full_image, 3)
    
# Style of the circled panel letters.
bbox_props1 = dict(boxstyle="circle,pad=0.15", fc="whitesmoke", ec="silver", lw=2)

# Red marker box drawn on every image panel, corners (20,20)-(21,21) in data
# coordinates. NOTE(review): imshow puts pixel centres on integers, so this
# box straddles four pixels. Also, the spectra below come from
# geoim[:, 20, 20] before the rot90 applied to these maps. Confirm the marker
# sits where intended.
marker_x = [20, 20, 21, 21, 20]
marker_y = [20, 21, 21, 20, 20]


def _image_panel(left, bottom, image, label, scale_note=None, **imshow_kw):
    """Add one xsize-by-0.25 image panel to fig; return (axes, image artist).

    The image uses origin='lower' and aspect=1 plus any extra imshow keywords.
    The red marker box, an optional white scale note at (28, 28) and the
    circled panel letter at (3, 3) are drawn on top.
    """
    ax = fig.add_axes([left, bottom, xsize, 0.25])  # [left, bottom, width, height]
    artist = ax.imshow(image, aspect=1, origin='lower', **imshow_kw)
    ax.plot(marker_x, marker_y, c='r')
    ax.set_axis_off()
    if scale_note is not None:
        ax.text(28, 28, scale_note, c='w', size='xx-large')
    ax.text(3, 3, label, ha="center", va="center", weight='bold',
            bbox=bbox_props1, size='x-large')
    return ax, artist


# Row 1: reflectance (a) and the 5.0-5.2 map (b).
ax1a, grey = _image_panel(xpos[0], 0.75, img_reflect, 'a', cmap='Greys_r', vmin=0)
ax1b, red = _image_panel(xpos[1], 0.75, img_heat, 'b', cmap=myred, vmin=0)

# Row 2: bbb window (c) and aaa window stretched 5x (d).
ax2a, blue = _image_panel(xpos[0], 0.5, bbb, 'c',
                          cmap=myblue, vmin=0, vmax=np.nanmax(bbb))
ax2b, purple1 = _image_panel(xpos[1], 0.5, aaa, 'd', scale_note='x5',
                             cmap=mypurple, vmin=0, vmax=np.nanmax(bbb)/5)

# Row 3: ab (e) and cd stretched 2x (f).
ax3a, green = _image_panel(xpos[0], 0.25, ab, 'e',
                           cmap=mygreen, vmin=0, vmax=np.nanmax(ab))
ax3b, purple2 = _image_panel(xpos[1], 0.25, cd, 'f', scale_note='x2',
                             cmap=mypurple, vmin=0, vmax=np.nanmax(ab)/2)

# Row 4: synthetic F323N image (g). No explicit aspect here, so imshow's
# default applies; decorations are added before the image, as before.
ax4 = fig.add_axes([xpos[0], 0.0, xsize, 0.25])  # [left, bottom, width, height]
ax4.text(3, 3, 'g', ha="center", va="center", weight='bold',
         bbox=bbox_props1, size='x-large')
ax4.plot(marker_x, marker_y, c='r')
copper = ax4.imshow(F323_full_image, origin='lower', cmap='copper', vmin=0)
ax4.set_axis_off()

# ax1.plot([1, 2, 3], [4, 5, 6])
# ax1.set_title('Subplot ',str(xp),str(yp))



 
    # if plot:

    # _plot(px, py, xc, yc, area, slices)



# Spectrum of spaxel (20, 20) of the cropped, unrotated cube, passed through
# geo.convert (units as labelled on the spectrum axes — TODO confirm).
spec = geo.convert(wave, geoim[:, 20, 20])

# Wavelength windows for the spectrum panels.
# (Original note: "cheats in the earlier fits".)
wavemin_methane = 3.27
wavemax_methane = 3.47
whw_methane = np.argwhere((wave > wavemin_methane) & (wave < wavemax_methane)).flatten()

wavemin_h3p = 3.5
wavemax_h3p = 4.4
whw_h3p = np.argwhere((wave > wavemin_h3p) & (wave < wavemax_h3p)).flatten()

# Broad overview window. It has its own names; previously it reused
# wavemin_methane/wavemax_methane, silently overwriting them with 2.8/5.2.
wavemin_heat = 2.8
wavemax_heat = 5.2
whw_heat = np.argwhere((wave > wavemin_heat) & (wave < wavemax_heat)).flatten()

# Filter-weighted spectrum in the F323 window, plus an unweighted copy scaled
# by 1/10 and clipped at 1.4x the weighted peak (display only).
spec_F323 = spec[whw_F323] * filter_F323[whw_F323]
spec_F323_pre = spec[whw_F323] / 10
f323_clip = np.max(spec_F323) * 1.4
spec_F323_pre[spec_F323_pre > f323_clip] = f323_clip

# cbar=fig.colorbar(red,ax=ax_cb_0,location='right',aspect=4,pad=5.6)
# cbar.ax.set_ylabel('Heat')

# Spectrum-column widths, as multiples of xsize.
xscale = 3.4
# Width factor of each of the three H3+ zoom panels.
xscale_h3p = 1.6/1.8*(xscale/2.73)

# Bottom spectrum panel (beside image g): the F323 window.
ax_cb_1 = fig.add_axes([xpos[2]+0.01, 0.012, xsize*xscale, 1/4*0.9])
# Filter-weighted spectrum (black) and unweighted spectrum / 10, clipped (orange).
ax_cb_1.plot(wave[whw_F323], spec_F323*1e3, linewidth=0.5, c='k', drawstyle='steps-mid')
ax_cb_1.plot(wave[whw_F323], spec_F323_pre*1e3, linewidth=0.5, c='darkorange',
             drawstyle='steps-mid')
# Filter transmission, rescaled so its peak sits at 1.4x the weighted-spectrum peak.
ax_cb_1.plot(wave[whw_F323],
             filter_F323[whw_F323]/np.max(filter_F323[whw_F323])*np.max(spec_F323)*1e3*1.4,
             linewidth=0.5, c='r', drawstyle='steps-mid')

# Outline of the methane zoom range (entries 25-150 of whw_methane).
f323_box_top = np.nanmax(spec_F323)*1e3*1.4
ax_cb_1.fill(
    [wave[whw_methane][25], wave[whw_methane][150], wave[whw_methane][150], wave[whw_methane][25]],
    [f323_box_top, f323_box_top, 0, 0],
    facecolor='none', edgecolor='dodgerblue', linewidth=1, alpha=0.8)

# Raw string: '\m' in a plain literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+); the rendered text is unchanged.
ax_cb_1.set_ylabel(r"W/m$^2$/sr/$\mu$m [x 1000]")
ax_cb_1.set_xlabel("Micron")
# Marks the orange curve as the /10-scaled spectrum.
ax_cb_1.text(3.23, 0.010, 'x10')
# ax_cb_2 = fig.add_axes([xpos[6]+0.01, 0.24, xsize*xscale,1/7]) 



# ax_cb_2.plot(wave[whw_h3p],spec[whw_h3p],linewidth=0.5,c='k')
# ax_cb_2.set_ylim([0,0.0003])
# ax_cb_2.set_yticks([])

# ax_cb_2.plot([wave[2239],wave[2239]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='green',alpha=0.5)

# ax_cb_2.plot([wave[1568],wave[1568]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='green',alpha=0.5)
# ax_cb_2.plot([wave[1642],wave[1642]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='green',alpha=0.5)

# ax_cb_2.plot([wave[1010],wave[1010]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='green',alpha=0.5)
# ax_cb_2.plot([wave[1011],wave[1011]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='green',alpha=0.5)
# ax_cb_2.plot([wave[1018],wave[1018]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='green',alpha=0.5)
# ax_cb_2.plot([wave[1019],wave[1019]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='green',alpha=0.5)
#     # ab=geoim[2239,:,:]+geoim[1568,:,:]+geoim[1642,:,:]+geoim[1010,:,:]+geoim[1011,:,:]+geoim[1018,:,:]+geoim[1019,:,:]
# ax_cb_2.plot([wave[2241],wave[2241]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='grey',alpha=0.5)
# ax_cb_2.plot([wave[2236],wave[2236]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='grey',alpha=0.5)

# ax_cb_2.plot([wave[1571],wave[1571]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='grey',alpha=0.5)
# ax_cb_2.plot([wave[1565],wave[1565]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='grey',alpha=0.5)
# ax_cb_2.plot([wave[1640],wave[1640]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='grey',alpha=0.5)

# ax_cb_2.plot([wave[1004],wave[1004]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='grey',alpha=0.5)
# ax_cb_2.plot([wave[1005],wave[1005]],[0,np.nanmax(spec[whw_h3p]*1e3)],c='grey',alpha=0.5)

# H3+ zoom panels (beside images e/f): three narrow slices of the whw_h3p
# window, side by side. Vertical bars mark the channels used in the maps:
# green = channels summed into `ab`, plum = channels weighted into `cd`.
h3p_bar_top = np.nanmax(spec[whw_h3p]*1e3)
h3p_panel_bottom = 0.25+0.015
h3p_panel_width = xsize*xscale_h3p
h3p_panel_height = 1/4*0.9


def _mark_channels(ax, channels, top, color, alpha, linewidth):
    """Draw a vertical bar from 0 to *top* at wave[ch] for each channel, in order."""
    for ch in channels:
        ax.plot([wave[ch], wave[ch]], [0, top], c=color, alpha=alpha, linewidth=linewidth)


# Panel 1: entries 30-70 of whw_h3p.
ax_cb_2 = fig.add_axes([xpos[2]+0.01, h3p_panel_bottom, h3p_panel_width, h3p_panel_height])
ax_cb_2.plot(wave[whw_h3p][30:70], spec[whw_h3p][30:70]*1e3,
             linewidth=0.5, c='k', drawstyle='steps-mid')
ax_cb_2.set_ylim([0, 0.00012*1e3])
# Raw string: '\m' in a plain literal is an invalid escape (SyntaxWarning on 3.12+).
ax_cb_2.set_ylabel(r"W/m$^2$/sr/$\mu$m [x 1000]")
_mark_channels(ax_cb_2, (1004, 1005), h3p_bar_top, 'plum', 0.5, 4)
_mark_channels(ax_cb_2, (1010, 1011, 1018, 1019), h3p_bar_top, 'green', 0.35, 4)

# Panel 2: entries 600-690 of whw_h3p.
ax_cb_2b = fig.add_axes([xpos[2]+0.02+xsize*xscale_h3p, h3p_panel_bottom,
                         h3p_panel_width, h3p_panel_height])
ax_cb_2b.plot(wave[whw_h3p][600:690], spec[whw_h3p][600:690]*1e3,
              linewidth=0.5, c='k', drawstyle='steps-mid')
ax_cb_2b.set_ylim([0, 0.0001*1e3])
ax_cb_2b.set_yticks([])
_mark_channels(ax_cb_2b, (1568, 1642), h3p_bar_top, 'green', 0.35, 3)
_mark_channels(ax_cb_2b, (1571, 1565, 1640), h3p_bar_top, 'plum', 0.5, 3)

# Panel 3: entries 1260-1300 of whw_h3p.
ax_cb_2c = fig.add_axes([xpos[2]+0.03+xsize*xscale_h3p*2, h3p_panel_bottom,
                         h3p_panel_width, h3p_panel_height])
ax_cb_2c.plot(wave[whw_h3p][1260:1300], spec[whw_h3p][1260:1300]*1e3,
              linewidth=0.5, c='k', drawstyle='steps-mid')
ax_cb_2c.set_ylim([0, 0.00012*1e3])
ax_cb_2c.set_yticks([])
_mark_channels(ax_cb_2c, (2239,), h3p_bar_top, 'green', 0.35, 6)
_mark_channels(ax_cb_2c, (2241, 2236), h3p_bar_top, 'plum', 0.5, 6)

# ax_cb.set_axis_off()


# Re-import ch4_fiddlesticks so edits to it are picked up when this cell is
# re-run interactively, then fit the extracted spectrum.
reload(ch4)

ch4fit = ch4.fit_non_LTE_CH4_JWST(wave)
fit = ch4fit.fit(wave, spec)

fit_sub = fit[whw_methane]
wave_sub = wave[whw_methane]

# Methane spectrum panel (beside images c/d): entries 25-150 of whw_methane.
ax_cb_3 = fig.add_axes([xpos[2]+0.01, 0.5+0.015, xsize*xscale, 1/4*0.9])
ax_cb_3.plot(wave_sub[25:150], spec[whw_methane][25:150]*1e3,
             linewidth=0.5, c='k', drawstyle='steps-mid')
ax_cb_3.set_ylim(0, np.max(spec[whw_methane][25:150]*1e3*1.1))
# Raw string: '\m' in a plain literal is an invalid escape (SyntaxWarning on 3.12+).
ax_cb_3.set_ylabel(r"W/m$^2$/sr/$\mu$m [x 1000]")

# Shade the 5-channel windows behind maps c (bbb, starting at 672+7) and
# d (aaa, starting at 672). The band height uses the data x 1e6, while the
# y-limit fixed above is on the x 1e3 scale, so the bands are clipped at the
# panel top (presumably intended to span the full height).
band_top = np.nanmax(spec[whw_methane]*1e6)
for start, face, edge in ((672+7, 'lightblue', 'blue'), (672, 'plum', 'violet')):
    ax_cb_3.fill([wave[start], wave[start+5], wave[start+5], wave[start]],
                 [band_top, band_top, 0, 0],
                 facecolor=face, edgecolor=edge, linewidth=0.2, alpha=0.8)



# Top spectrum panel (beside images a/b): the broad whw_heat overview.
ax_cb_4 = fig.add_axes([xpos[2]+0.01, 0.75+0.015, xsize*xscale, 1/4*0.9])
ax_cb_4.plot(wave[whw_heat], spec[whw_heat], linewidth=0.2, c='k', drawstyle='steps-mid')
overview_top = np.nanmax(spec[whw_heat])


def _overview_band(lo, hi, top, **fill_kw):
    """Fill the rectangle [lo, hi] x [0, top] on the overview panel."""
    ax_cb_4.fill([lo, hi, hi, lo], [top, top, 0, 0], **fill_kw)


# Filled bands: the windows used for img_reflect (3.0-3.2) and img_heat (5.0-5.2).
_overview_band(3.0, 3.2, overview_top,
               facecolor='grey', edgecolor='darkgrey', linewidth=0.2, alpha=0.8)
_overview_band(5.0, 5.2, overview_top,
               facecolor='lightsalmon', edgecolor='orangered', linewidth=0.2, alpha=0.8)
# Raw string: '\m' in a plain literal is an invalid escape (SyntaxWarning on 3.12+).
ax_cb_4.set_ylabel(r"W/m$^2$/sr/$\mu$m")

# Outlines of the ranges shown in the zoom panels, drawn in the original order.
for sel, lo_idx, hi_idx, edge in ((whw_h3p, 30, 70, 'green'),
                                  (whw_h3p, 600, 690, 'green'),
                                  (whw_h3p, 1260, 1300, 'green'),
                                  (whw_methane, 25, 150, 'dodgerblue'),
                                  (whw_F323, 0, -1, 'orange')):
    _overview_band(wave[sel][lo_idx], wave[sel][hi_idx], overview_top,
                   facecolor='none', edgecolor=edge, linewidth=1, alpha=0.8)

# 
# cbar=fig.colorbar(blue,ax=ax_cb_3,location='right',aspect=4)
# cbar.ax.set_ylabel('Methane')

# cbar=fig.colorbar(green,ax=ax_cb_2,location='right',aspect=4)
# cbar.ax.set_ylabel('H3+')

# cbar=fig.colorbar(red,ax=ax_cb_1,location='right',aspect=4)
# cbar.ax.set_ylabel('Heat')

# cbar=fig.colorbar(ax1,ax=ax1,location='right', ticks=[0.4,0.6,0.8,1.0,1.2])
#                     cbar.ax.set_ylabel('Methane')


# Write the figure (relative path: current working directory), then display it.
fig.savefig('asd_fig1b.pdf', pad_inches=0, facecolor='white',
            bbox_inches='tight', dpi=300)
plt.show()

